import inspect
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
import numpy as np
import random

from main import setup_seed,gen_data,save_arch,load_arch
from models import *
"""
https://github.com/wy1iu/LargeMargin_Softmax_Loss/blob/master/myexamples/cifar10/cifar_solver.prototxt
"""
class Model(nn.Module):
    """Backbone ``CNN()`` followed by a bias-free linear classification head.

    Args:
        arch: head type; only ``'soft'`` (plain linear logits) is implemented.
        num_classes: number of output logits (default 10).
        feat_dim: width of the backbone features fed to the classifier
            (default 512; presumably the output size of ``CNN`` — TODO confirm).
    """

    def __init__(self, arch, num_classes=10, feat_dim=512):
        super().__init__()
        # CNN comes from the wildcard `models` import.
        self.cnn = CNN()
        self.classifier = nn.Linear(feat_dim, num_classes, bias=False)
        self.arch = arch

    def forward(self, input, target):
        """Return classification logits for ``input``.

        ``target`` is accepted but unused by the ``'soft'`` head; presumably it
        is kept in the signature for label-dependent (margin) heads — TODO confirm.

        Raises:
            ValueError: if ``self.arch`` is not a supported head type.
        """
        # Validate before running the backbone so a bad config fails cheaply
        # and without touching any backbone state.
        if self.arch != 'soft':
            raise ValueError(f"unsupported arch {self.arch!r}; expected 'soft'")
        features = self.cnn(input)
        return self.classifier(features)


def acc(model, loss_func, loader, device='cuda'):
    """Evaluate ``model`` over ``loader`` and return ``(mean_loss, accuracy)``.

    The model is put in eval mode and autograd is disabled for the pass; the
    model's previous train/eval mode is restored afterwards, even on error.

    Args:
        model: module called as ``model(inputs, targets)``; the argmax of its
            output along dim 1 is taken as the predicted class.
        loss_func: loss that averages over the batch (e.g. the default
            ``nn.CrossEntropyLoss()``); it is re-weighted by batch size so the
            result is a per-sample mean even when the last batch is smaller.
        loader: iterable of ``(inputs, targets)`` batches.
        device: device each batch is moved to.

    Returns:
        ``(loss, accuracy)``: per-sample mean loss and the fraction of
        correctly classified samples.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs, targets)
                batch_size = targets.size(0)
                # loss_func returns a batch mean; undo it so batches are
                # weighted by how many samples they contain.
                loss_sum += loss_func(outputs, targets).item() * batch_size
                correct += outputs.argmax(1).eq(targets).sum().item()
                total += batch_size
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError('loader yielded no samples')
    return loss_sum / total, correct / total

def setlr(opt, lr):
    """Overwrite the learning rate of every parameter group in ``opt``.

    Args:
        opt: optimizer exposing ``param_groups`` (a list of group dicts).
        lr: new learning rate, applied uniformly to all groups.

    Mutates ``opt`` in place and returns ``None``.
    """
    for group in opt.param_groups:
        group.update(lr=lr)

def adjust_lr(it, opt):
    """Apply the step learning-rate schedule at its milestone iterations.

    Iteration 0 sets the rate to 0.1, 5000 sets 1e-2 and 7500 sets 1e-3; any
    other iteration leaves ``opt`` untouched. Milestones are matched exactly,
    so this must be called on every iteration for them to fire.

    NOTE(review): the last drop happens at 7500 while ``main`` trains for
    30000 iterations — confirm this schedule is intended.
    """
    milestones = {0: 0.1, 5000: 1e-2, 7500: 1e-3}
    new_lr = milestones.get(it)
    if new_lr is not None:
        setlr(opt, new_lr)

def main():
    """Train ``Model('soft')`` with SGD and report train/validation/test metrics.

    Runs SGD (lr 0.1, momentum 0.9, weight decay 5e-4) under the step schedule
    of ``adjust_lr`` for at most ``max_iter`` iterations. After every pass over
    the training loader it prints
    ``iteration train_loss train_acc valid_loss valid_acc``, then prints the
    final metrics on the test loader.
    """
    max_iter = 30000
    # Fall back to CPU so the script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model('soft').to(device)
    loss_func = nn.CrossEntropyLoss().to(device)
    # Build the optimizer after moving the model so it tracks the final parameters.
    opt = optim.SGD(model.parameters(), 0.1, momentum=0.9, weight_decay=0.0005)
    # Loader order (train, valid, test) as unpacked from gen_data.
    trainloader, validloader, testloader = gen_data('../work/data', 128)

    it = 0
    while it < max_iter:
        # Re-enter training mode every epoch in case evaluation switched the
        # model to eval mode.
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0

        for inputs, targets in trainloader:
            adjust_lr(it, opt)
            inputs, targets = inputs.to(device), targets.to(device)
            opt.zero_grad()
            outputs = model(inputs, targets)
            loss = loss_func(outputs, targets)
            loss.backward()
            opt.step()

            batch_size = targets.size(0)
            # CrossEntropyLoss averages over the batch; weight by batch size so
            # dividing by `total` below yields a per-sample mean.
            train_loss += loss.item() * batch_size
            correct += outputs.argmax(1).eq(targets).sum().item()
            total += batch_size

            it += 1
            # Honour the iteration budget exactly instead of finishing the epoch.
            if it >= max_iter:
                break

        if total == 0:
            # An empty loader can never advance `it` toward max_iter.
            raise ValueError('training loader yielded no samples')
        train_loss /= total
        correct /= total

        valid_loss, valid_acc = acc(model, loss_func, validloader, device)
        print(f'{it:5d} {train_loss:.2f} {correct:.2f} {valid_loss:.2f} {valid_acc:.2f}')

    # Final report uses the held-out test split, not the validation split
    # that was used for monitoring during training.
    test_loss, test_acc = acc(model, loss_func, testloader, device)

    print(f'final {test_loss:.2f} {test_acc:.2f}')


# Run training only when executed as a script, not on import.
if __name__ == '__main__':
    main()
